import torch
import torchvision

vgg16 = torchvision.models.vgg16()

# 保存方式1，模型参数+模型结构
torch.save(vgg16, "model_method1.pth")
# 自己创建的模型加载时需要注意要把自己的模型导入进来，否则会找不到对应的模型
# model = torch.load("model_method1.pth")

# 保存方式2，模型参数（官方推荐方式）
# torch.save(vgg16.state_dict(), "model_method2.pth")
# # 加载时需要创建模型，让模型加载参数
# state_dict = torch.load("model_method2.pth")
# vgg16_2 = torchvision.models.vgg16()
# vgg16_2.load_state_dict(state_dict)